{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 기본 그래프 생성\n",
    "\n",
    "이번 튜토리얼에서는 LangGraph를 사용하여 그래프를 생성하는 방법을 배웁니다.\n",
    "\n",
    "LangGraph 의 그래프를 정의하기 위해서는\n",
    "\n",
    "1. State 정의\n",
    "2. 노드 정의\n",
    "3. 그래프 정의\n",
    "4. 그래프 컴파일\n",
    "5. 그래프 시각화\n",
    "\n",
    "단계를 거칩니다.\n",
    "\n",
    "그래프 생성시 조건부 엣지를 사용하는 방법과 다양한 흐름 변경 방법을 알아봅니다.\n",
    "\n",
    "![langgraph-building-graphs](assets/langgraph-building-graphs.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State 정의"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import TypedDict, Annotated, List\n",
    "from langchain_core.documents import Document\n",
    "import operator\n",
    "\n",
    "\n",
    "# State 정의\n",
    "class GraphState(TypedDict):\n",
    "    context: Annotated[List[Document], operator.add]\n",
    "    answer: Annotated[List[Document], operator.add]\n",
    "    question: Annotated[str, \"user question\"]\n",
    "    sql_query: Annotated[str, \"sql query\"]\n",
    "    binary_score: Annotated[str, \"binary score yes or no\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 노드 정의"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve(state: GraphState) -> GraphState:\n",
    "    # retrieve: 검색\n",
    "    documents = \"검색된 문서\"\n",
    "    return {\"context\": documents}\n",
    "\n",
    "\n",
    "def rewrite_query(state: GraphState) -> GraphState:\n",
    "    # Query Transform: 쿼리 재작성\n",
    "    documents = \"검색된 문서\"\n",
    "    return GraphState(context=documents)\n",
    "\n",
    "\n",
    "def llm_gpt_execute(state: GraphState) -> GraphState:\n",
    "    # LLM 실행\n",
    "    answer = \"GPT 생성된 답변\"\n",
    "    return GraphState(answer=answer)\n",
    "\n",
    "\n",
    "def llm_claude_execute(state: GraphState) -> GraphState:\n",
    "    # LLM 실행\n",
    "    answer = \"Claude 의 생성된 답변\"\n",
    "    return GraphState(answer=answer)\n",
    "\n",
    "\n",
    "def relevance_check(state: GraphState) -> GraphState:\n",
    "    # Relevance Check: 관련성 확인\n",
    "    binary_score = \"Relevance Score\"\n",
    "    return GraphState(binary_score=binary_score)\n",
    "\n",
    "\n",
    "def sum_up(state: GraphState) -> GraphState:\n",
    "    # sum_up: 결과 종합\n",
    "    answer = \"종합된 답변\"\n",
    "    return GraphState(answer=answer)\n",
    "\n",
    "\n",
    "def search_on_web(state: GraphState) -> GraphState:\n",
    "    # Search on Web: 웹 검색\n",
    "    documents = state[\"context\"] = \"기존 문서\"\n",
    "    searched_documents = \"검색된 문서\"\n",
    "    documents += searched_documents\n",
    "    return GraphState(context=documents)\n",
    "\n",
    "\n",
    "def get_table_info(state: GraphState) -> GraphState:\n",
    "    # Get Table Info: 테이블 정보 가져오기\n",
    "    table_info = \"테이블 정보\"\n",
    "    return GraphState(context=table_info)\n",
    "\n",
    "\n",
    "def generate_sql_query(state: GraphState) -> GraphState:\n",
    "    # Make SQL Query: SQL 쿼리 생성\n",
    "    sql_query = \"SQL 쿼리\"\n",
    "    return GraphState(sql_query=sql_query)\n",
    "\n",
    "\n",
    "def execute_sql_query(state: GraphState) -> GraphState:\n",
    "    # Execute SQL Query: SQL 쿼리 실행\n",
    "    sql_result = \"SQL 결과\"\n",
    "    return GraphState(context=sql_result)\n",
    "\n",
    "\n",
    "def validate_sql_query(state: GraphState) -> GraphState:\n",
    "    # Validate SQL Query: SQL 쿼리 검증\n",
    "    binary_score = \"SQL 쿼리 검증 결과\"\n",
    "    return GraphState(binary_score=binary_score)\n",
    "\n",
    "\n",
    "def handle_error(state: GraphState) -> GraphState:\n",
    "    # Error Handling: 에러 처리\n",
    "    error = \"에러 발생\"\n",
    "    return GraphState(context=error)\n",
    "\n",
    "\n",
    "def decision(state: GraphState) -> GraphState:\n",
    "    # 의사결정\n",
    "    decision = \"결정\"\n",
    "    # 로직을 추가할 수 가 있고요.\n",
    "\n",
    "    if state[\"binary_score\"] == \"yes\":\n",
    "        return \"종료\"\n",
    "    else:\n",
    "        return \"재검색\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 그래프 정의"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END, StateGraph\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from langchain_teddynote.graphs import visualize_graph\n",
    "\n",
    "# (1): Conventional RAG\n",
    "# (2): 재검색\n",
    "# (3): 멀티 LLM\n",
    "# (4): 쿼리 재작성\n",
    "\n",
    "\n",
    "# langgraph.graph에서 StateGraph와 END를 가져옵니다.\n",
    "workflow = StateGraph(GraphState)\n",
    "\n",
    "# 노드를 추가합니다.\n",
    "workflow.add_node(\"retrieve\", retrieve)\n",
    "\n",
    "# workflow.add_node(\"rewrite_query\", rewrite_query)  # (4)\n",
    "\n",
    "workflow.add_node(\"GPT 요청\", llm_gpt_execute)\n",
    "# workflow.add_node(\"Claude 요청\", llm_claude_execute)  # (3)\n",
    "workflow.add_node(\"GPT_relevance_check\", relevance_check)\n",
    "# workflow.add_node(\"Claude_relevance_check\", relevance_check)  # (3)\n",
    "workflow.add_node(\"결과 종합\", sum_up)\n",
    "\n",
    "# 각 노드들을 연결합니다.\n",
    "workflow.add_edge(\"retrieve\", \"GPT 요청\")\n",
    "# workflow.add_edge(\"retrieve\", \"Claude 요청\")  # (3)\n",
    "# workflow.add_edge(\"rewrite_query\", \"retrieve\")  # (4)\n",
    "workflow.add_edge(\"GPT 요청\", \"GPT_relevance_check\")\n",
    "workflow.add_edge(\"GPT_relevance_check\", \"결과 종합\")\n",
    "# workflow.add_edge(\"Claude 요청\", \"Claude_relevance_check\")  # (3)\n",
    "# workflow.add_edge(\"Claude_relevance_check\", \"결과 종합\")  # (3)\n",
    "\n",
    "workflow.add_edge(\"결과 종합\", END)  # (2) - off\n",
    "\n",
    "# 조건부 엣지를 추가합니다. (2), (4)\n",
    "# workflow.add_conditional_edges(\n",
    "#     \"결과 종합\",  # 관련성 체크 노드에서 나온 결과를 is_relevant 함수에 전달합니다.\n",
    "#     decision,\n",
    "#     {\n",
    "#         \"재검색\": \"retrieve\",  # 관련성이 있으면 종료합니다.\n",
    "#         \"종료\": END,  # 관련성 체크 결과가 모호하다면 다시 답변을 생성합니다.\n",
    "#     },\n",
    "# )\n",
    "\n",
    "# 조건부 엣지를 추가합니다. (4)\n",
    "# workflow.add_conditional_edges(\n",
    "#     \"결과 종합\",  # 관련성 체크 노드에서 나온 결과를 is_relevant 함수에 전달합니다.\n",
    "#     decision,\n",
    "#     {\n",
    "#         \"재검색\": \"rewrite_query\",  # 관련성이 있으면 종료합니다.\n",
    "#         \"종료\": END,  # 관련성 체크 결과가 모호하다면 다시 답변을 생성합니다.\n",
    "#     },\n",
    "# )\n",
    "\n",
    "# 시작점을 설정합니다.\n",
    "workflow.set_entry_point(\"retrieve\")\n",
    "\n",
    "# 기록을 위한 메모리 저장소를 설정합니다.\n",
    "memory = MemorySaver()\n",
    "\n",
    "# 그래프를 컴파일합니다.\n",
    "app = workflow.compile(checkpointer=memory)\n",
    "\n",
    "# 그래프 시각화\n",
    "visualize_graph(app)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SQL RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "from langgraph.graph import END, StateGraph\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "# langgraph.graph에서 StateGraph와 END를 가져옵니다.\n",
    "workflow = StateGraph(GraphState)\n",
    "\n",
    "# 노드를 추가합니다.\n",
    "workflow.add_node(\"질문\", retrieve)\n",
    "workflow.add_node(\"rewrite_query\", rewrite_query)\n",
    "workflow.add_node(\"rewrite_question\", rewrite_query)\n",
    "workflow.add_node(\"GPT 요청\", llm_gpt_execute)\n",
    "workflow.add_node(\"GPT_relevance_check\", relevance_check)\n",
    "workflow.add_node(\"결과 종합\", sum_up)\n",
    "workflow.add_node(\"get_table_info\", get_table_info)\n",
    "workflow.add_node(\"generate_sql_query\", generate_sql_query)\n",
    "workflow.add_node(\"execute_sql_query\", execute_sql_query)\n",
    "workflow.add_node(\"validate_sql_query\", validate_sql_query)\n",
    "\n",
    "# 각 노드들을 연결합니다.\n",
    "workflow.add_edge(\"질문\", \"get_table_info\")\n",
    "workflow.add_edge(\"get_table_info\", \"generate_sql_query\")\n",
    "workflow.add_edge(\"generate_sql_query\", \"execute_sql_query\")\n",
    "workflow.add_edge(\"execute_sql_query\", \"validate_sql_query\")\n",
    "\n",
    "workflow.add_conditional_edges(\n",
    "    \"validate_sql_query\",\n",
    "    decision,\n",
    "    {\n",
    "        \"QUERY ERROR\": \"rewrite_query\",\n",
    "        \"UNKNOWN MEANING\": \"rewrite_question\",\n",
    "        \"PASS\": \"GPT 요청\",\n",
    "    },\n",
    ")\n",
    "\n",
    "workflow.add_edge(\"rewrite_query\", \"execute_sql_query\")\n",
    "workflow.add_edge(\"rewrite_question\", \"rewrite_query\")\n",
    "workflow.add_edge(\"GPT 요청\", \"GPT_relevance_check\")\n",
    "workflow.add_edge(\"GPT_relevance_check\", \"결과 종합\")\n",
    "workflow.add_edge(\"결과 종합\", END)\n",
    "\n",
    "# 시작점을 설정합니다.\n",
    "workflow.set_entry_point(\"질문\")\n",
    "\n",
    "# 기록을 위한 메모리 저장소를 설정합니다.\n",
    "memory = MemorySaver()\n",
    "\n",
    "# 그래프를 컴파일합니다.\n",
    "app = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "그래프를 시각화 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.graphs import visualize_graph\n",
    "\n",
    "visualize_graph(app)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py-test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
